{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# Predictions with TensorRT\n",
    "In this notebook we show how to use the TensorRT engine that we created after training our model and serialized to disk.\n",
    "\n",
    "### Imports\n",
    "We import here the packages that are required to run this notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorrt.lite import Engine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "import numpy as np\n",
    "import os\n",
    "import functools\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Batch size used when grouping the images for inference.\n",
    "BATCH_SIZE = 1\n",
    "\n",
    "# Directory that holds the images to classify.\n",
    "IMAGE_DIR = '/home/data/val/roses'\n",
    "\n",
    "# Serialized engine files for batch size 1 (the file names suggest one FP32\n",
    "# and one FP16 variant -- confirm against how the engines were built).\n",
    "PLAN_single, PLAN_half = (\n",
    "    '/home/model/keras_vgg19_b1_fp32.engine',\n",
    "    '/home/model/keras_vgg19_b1_fp16.engine',\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Utility functions\n",
    "We define here a few utility functions. These functions are used to \n",
    "* Analyze the prediction\n",
    "* Convert image to a format that is identical to the format used during training\n",
    "* Organize the images into batches stored in a NumPy array\n",
    "* Time the compute time of a function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def analyze(output_data):\n",
    "    \"\"\"Pick the best-scoring flower label for every sample in output_data.\n",
    "\n",
    "    output_data is reshaped to one row per sample with one column per label.\n",
    "    Returns (top_classes, top_classes_prob): the list of winning label names\n",
    "    and the array holding the maximum value of each row.\n",
    "    \"\"\"\n",
    "    labels = ('daisy', 'dandelion', 'roses', 'sunflowers', 'tulips')\n",
    "    scores = output_data.reshape(-1, len(labels))\n",
    "\n",
    "    winners = scores.argmax(axis=1)\n",
    "    top_classes = [labels[k] for k in winners]\n",
    "    top_classes_prob = scores.max(axis=1)\n",
    "    return top_classes, top_classes_prob\n",
    "\n",
    "\n",
    "def image_to_np_CHW(image):\n",
    "    \"\"\"Turn a PIL image into a float32 array in channel-first (C, H, W) order.\n",
    "\n",
    "    The image is converted to 3-channel RGB before resizing to 224x224, so\n",
    "    grayscale, palette and RGBA inputs no longer break the transpose or\n",
    "    produce a channel count that differs from the other images.\n",
    "    \"\"\"\n",
    "    rgb = image.convert('RGB')\n",
    "    # LANCZOS is the filter that the older ANTIALIAS name referred to.\n",
    "    resized = rgb.resize((224, 224), Image.LANCZOS)\n",
    "    return np.asarray(resized).transpose([2, 0, 1]).astype(np.float32)\n",
    "\n",
    "\n",
    "def load_and_preprocess_images():\n",
    "    \"\"\"Load every file in IMAGE_DIR and group the images into batches.\n",
    "\n",
    "    Files are read in sorted name order so that the order of the images (and\n",
    "    of the results computed from them) is the same on every run.\n",
    "\n",
    "    Returns an array of shape (num_batches, BATCH_SIZE, C, H, W), where the\n",
    "    last three dimensions come from image_to_np_CHW. Images that do not fill\n",
    "    a complete batch are dropped.\n",
    "\n",
    "    Raises ValueError if IMAGE_DIR contains no files.\n",
    "    \"\"\"\n",
    "    candidates = [os.path.join(IMAGE_DIR, name) for name in sorted(os.listdir(IMAGE_DIR))]\n",
    "    paths = [p for p in candidates if os.path.isfile(p)]\n",
    "    if not paths:\n",
    "        raise ValueError('no image files found in {}'.format(IMAGE_DIR))\n",
    "\n",
    "    arrays = []\n",
    "    for path in paths:\n",
    "        # The context manager releases the file handle once the pixels are copied out.\n",
    "        with Image.open(path) as img:\n",
    "            arrays.append(image_to_np_CHW(img))\n",
    "    stacked = np.stack(arrays)\n",
    "\n",
    "    # Keep only as many images as fit into whole batches.\n",
    "    num_batches = len(stacked) // BATCH_SIZE\n",
    "    channels, height, width = stacked.shape[1:]\n",
    "    usable = stacked[:num_batches * BATCH_SIZE]\n",
    "    return np.reshape(usable, [num_batches, BATCH_SIZE, channels, height, width])\n",
    "\n",
    "\n",
    "def timeit(func):\n",
    "    \"\"\"Decorator that prints how long each call to func took, in whole milliseconds.\n",
    "\n",
    "    The wrapped function's return value is passed through unchanged and its\n",
    "    metadata (such as __name__) is preserved via functools.wraps.\n",
    "    \"\"\"\n",
    "    def timed(*args, **kwargs):\n",
    "        t0 = time.time()\n",
    "        outcome = func(*args, **kwargs)\n",
    "        millis = int((time.time() - t0) * 1000)\n",
    "        print('function [{}] finished in {} ms'.format(func.__name__, millis))\n",
    "        return outcome\n",
    "\n",
    "    return functools.wraps(func)(timed)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare TensorRT Engine\n",
    "Here we simply load the TRT engine so that we can run inference. We can also attach a utility function that postprocesses the outputs before they are returned (in this case we use the function `analyze`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_TRT_engine(plan):\n",
    "    \"\"\"Create a tensorrt.lite Engine from the serialized plan file `plan`.\n",
    "\n",
    "    `analyze` is registered as the postprocessor for the 'dense_2/Softmax' output.\n",
    "    \"\"\"\n",
    "    postprocessors = {'dense_2/Softmax': analyze}\n",
    "    return Engine(PLAN=plan, postprocessors=postprocessors)\n",
    "\n",
    "# Build one engine per serialized plan file (PLAN_single and PLAN_half).\n",
    "engine_single = load_TRT_engine(PLAN_single)\n",
    "engine_half = load_TRT_engine(PLAN_half)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load all data\n",
    "Here we load all the test data from the directory specified above in \"IMAGE_DIR\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Batched array of the images found in IMAGE_DIR; used by the inference and display cells.\n",
    "images_trt = load_and_preprocess_images()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare function to do inference with TensorRT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@timeit\n",
    "def infer_all_images_trt(engine):\n",
    "    \"\"\"Call engine.infer on every batch of the global images_trt, in order.\n",
    "\n",
    "    Returns the list of per-batch results. Because of the timeit decorator,\n",
    "    the elapsed time of the whole pass is printed when the call returns.\n",
    "    \"\"\"\n",
    "    return [engine.infer(batch) for batch in images_trt]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run inference over all batches with each engine; the timeit decorator prints the elapsed time of each pass.\n",
    "results_trt_single = infer_all_images_trt(engine_single)\n",
    "results_trt_half = infer_all_images_trt(engine_half)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For each batch: show the first channel of its first image in grayscale,\n",
    "# then print the [0][0][0] entry of the result from each engine\n",
    "# (engine_single first, engine_half second).\n",
    "for i in range(len(results_trt_single)):\n",
    "    plt.imshow(images_trt[i, 0, 0], cmap='gray')\n",
    "    plt.show()\n",
    "    # Call form of print: a single argument prints identically on Python 2 and 3.\n",
    "    print(results_trt_single[i][0][0][0])\n",
    "    print(results_trt_half[i][0][0][0])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
